/**
 * LeetCode 928 "Minimize Malware Spread II".
 *
 * Removing an initially infected node deletes it (and its edges) from the
 * graph. The best node to remove is the one that is the *sole* infection
 * source for the most clean nodes. Here "reaches" means reaches without
 * passing through any other initially infected node.
 *
 * Cost: O(n^2) to build adjacency lists, plus one BFS of O(n + E) per
 * entry of `initial`.
 */
class Solution {
public:
    /**
     * @param graph   n x n adjacency matrix; graph[i][j] == 1 links i and j.
     *                An edge given in either direction is treated as undirected.
     * @param initial indices of the initially infected nodes; not modified.
     * @return the node of `initial` whose removal saves the most nodes, with
     *         ties broken by the smallest index; -1 if `initial` is empty.
     */
    int minMalwareSpread(std::vector<std::vector<int>>& graph, std::vector<int>& initial)
    {
        const int n = static_cast<int>(graph.size());

        std::vector<char> isInitial(n, 0);
        for (int v : initial)
            isInitial[v] = 1;

        // Record each undirected edge exactly once in each endpoint's list.
        std::vector<std::vector<int>> adj(n);
        for (int i = 0; i < n; ++i)
            for (int j = i + 1; j < n; ++j)
                if (graph[i][j] == 1 || graph[j][i] == 1) {
                    adj[i].push_back(j);
                    adj[j].push_back(i);
                }

        // hits[v]: number of BFS runs (one per entry of `initial`) that reached
        //          clean node v.
        // lastSource[v]: start node of the most recent run that reached v.
        std::vector<int> hits(n, 0);
        std::vector<int> lastSource(n, -1);
        // seen[v] == k  <=>  v was visited during the k-th run. Stamping by run
        // index lets one buffer be reused without clearing it between runs.
        std::vector<int> seen(n, -1);
        std::queue<int> frontier;

        for (int k = 0; k < static_cast<int>(initial.size()); ++k) {
            const int start = initial[k];
            seen[start] = k;
            frontier.push(start);
            while (!frontier.empty()) {
                const int cur = frontier.front();
                frontier.pop();
                for (int next : adj[cur]) {
                    // Do not walk through another infected node: whatever lies
                    // behind it stays infected even if `start` is removed.
                    if (seen[next] == k || isInitial[next])
                        continue;
                    seen[next] = k;
                    ++hits[next];
                    lastSource[next] = start;
                    frontier.push(next);
                }
            }
        }

        // saved[s]: clean nodes that only s reaches, i.e. nodes spared by removing s.
        std::vector<int> saved(n, 0);
        for (int v = 0; v < n; ++v)
            if (hits[v] == 1)
                ++saved[lastSource[v]];

        // The highest saved count wins. Ties, including "nobody is saved",
        // go to the smallest node index.
        int best = -1;
        int bestSaved = -1;
        for (int s : initial)
            if (saved[s] > bestSaved || (saved[s] == bestSaved && s < best)) {
                best = s;
                bestSaved = saved[s];
            }
        return best;
    }
};
